import torch
import tltorch
import tntorch
import collections

import math

import torch.nn.functional as F
import tensorly as tly

# Make tensorly operate on PyTorch tensors (tly.unfold / tly.fold are applied to torch tensors in this module).
# NOTE: this and the next line are process-wide side effects executed at import time.
tly.set_backend('pytorch')
# Global default dtype: tensors created without an explicit dtype (e.g. torch.empty/torch.eye) become float32.
torch.set_default_dtype(torch.float32)

import numpy as np

from einops import rearrange

from .tucker_conv_base import _ConvNd
from .tucker_conv_base import _pair

# Python 3.10 removed the ABC aliases (collections.Iterable, ...) from the `collections` namespace, but some
# older third-party code still looks them up there. Re-export them from `collections.abc`.
# Import the submodule explicitly: `import collections` alone does not guarantee `collections.abc` is bound.
import collections.abc

for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name


class Conv2d_tucker_adaptive(_ConvNd):  # adaptive rank tucker layer
    """2-D convolution whose kernel is stored in Tucker format with adaptive ranks.

    The 4-way kernel (out_channels, in_channels, *kernel_size) is represented by a core tensor ``C``
    and one basis matrix per mode in ``Us``. ``forward`` dispatches on ``self.step``:
    ``1``-``4`` (K-step for mode 0-3), ``'core'`` (augmented core step), ``'test'`` (evaluation with
    the current ranks) or ``None`` (dense ``F.conv2d`` with ``self.weight``). The
    ``*_preprocess_step`` / ``*_postprocess_step`` methods prepare and finalise each phase;
    ``S_postprocess_step`` truncates the ranks with tolerance ``tau``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride,
            padding,
            dilation,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = "zeros",
            dtype=None,
            device=None,
            low_rank_percent=None,
            tau=0.1
    ) -> None:
        """
        Initializer for the adaptive-rank Tucker convolutional layer, an extension of PyTorch's Conv2d.

        INPUTS:
        in_channels : number of input channels (PyTorch's standard)
        out_channels : number of output channels (PyTorch's standard)
        kernel_size : int or pair, kernel size of the convolutional filter (PyTorch's standard)
        stride : stride of the filter (PyTorch's standard)
        padding : int or pair, padding of the convolution (PyTorch's standard)
        dilation : dilation of the convolution (PyTorch's standard)
        groups : forwarded to the base class (PyTorch's standard)
        bias : flag variable for the bias to be included (PyTorch's standard)
        padding_mode : forwarded to the base class (PyTorch's standard)
        dtype : dtype of every tensor created by this layer (None -> torch default dtype)
        device : device of every tensor created by this layer (None -> torch default device)
        low_rank_percent : forwarded unchanged to the base class
        tau : tolerance passed as ``eps`` to ``tntorch.round_tucker`` in ``S_postprocess_step``
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=list(kernel_size) if isinstance(kernel_size, (list, tuple)) else [kernel_size] * 2,
            stride=stride,
            padding=list(padding) if isinstance(padding, (list, tuple)) else [padding] * 2,
            dilation=dilation,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
            # ====================== DLRT params =========================================
            low_rank_percent=low_rank_percent,
            tau=tau
        )

        self.device, self.dtype = device, dtype
        self.tau = tau
        self.fixed = False
        self.dims = [self.out_channels, self.in_channels] + self.kernel_size
        self.rmax = list(self.dims)
        self.rank = self.rmax  # keeps referring to the unclipped list of mode sizes
        self.dynamic_rank = list(self.rank)
        # A mode-i Tucker rank cannot exceed the product of the other ranks (r_i <= prod_{j != i} r_j),
        # so clip both the starting (dynamic) ranks and the maximal ranks accordingly.
        self.dynamic_rank = [min(d, self.get_r_mod_i(i)) for i, d in enumerate(self.dims)]
        self.rmax = [min(d, self.get_r_mod_i(i)) for i, d in enumerate(self.dims)]

        if self.bias is not None:
            self.bias = torch.nn.Parameter(
                torch.zeros(self.out_channels, requires_grad=False, **factory_kwargs),
                requires_grad=False,
            )

        # Core tensor, allocated at twice the starting ranks so it can also hold the augmented core.
        self.C = torch.nn.Parameter(torch.empty(size=[2 * s for s in self.dynamic_rank], **factory_kwargs))
        # Mode bases U_i of shape (dims[i], rmax[i]); frozen (requires_grad=False), overwritten in
        # S_postprocess_step. Create them directly on the target device: calling `.to(device)` on a
        # Parameter returns a plain tensor, which ParameterList would re-wrap as a *trainable* Parameter.
        self.Us = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(d, r, **factory_kwargs), requires_grad=False)
             for d, r in zip(self.dims, self.rmax)])

        # K variables (one per mode), trained during the K steps.
        self.Ks = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(d, r, **factory_kwargs)) for (d, r) in zip(self.dims, self.rmax)])

        # Augmented bases [K_i | U_i] after orthonormalisation, up to 2 * rmax[i] columns.
        self.U_hats = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(d, 2 * r, **factory_kwargs), requires_grad=False)
             for (d, r) in zip(self.dims, self.rmax)])

        # Cores used by the K steps; Qst[i] has shape (rmax[i], *rmax[j != i]).
        self.Qst = torch.nn.ParameterList()
        for i, r in enumerate(self.rmax):
            other_dims = [d for j, d in enumerate(self.rmax) if j != i]
            self.Qst.append(torch.nn.Parameter(torch.empty(r, *other_dims, **factory_kwargs), requires_grad=False))

        # Matrices M_i = U_hat_i^T U_i mapping the old basis into the augmented one (shape 2r x r).
        self.M_hats = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(2 * r, r, **factory_kwargs), requires_grad=False)
             for r in self.rmax])

        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Randomly (Kaiming-uniform) initialise every low-rank tensor and, if present, the bias."""
        torch.nn.init.kaiming_uniform_(self.C, a=math.sqrt(5))
        for i in range(len(self.dims)):
            torch.nn.init.kaiming_uniform_(self.Us[i], a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.Qst[i], a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.U_hats[i], a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.Ks[i], a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.M_hats[i], a=math.sqrt(5))

        if self.bias is not None:
            # Same bias bound as torch.nn.Conv2d: U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
            weight = torch.empty(
                (self.out_channels, self.in_channels // self.groups, *self.kernel_size),
                device=self.device,
                dtype=self.dtype,
            )
            torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                torch.nn.init.uniform_(self.bias, -bound, bound)
            del weight

        self.basic_number_weights = self.in_channels * (self.out_channels // self.groups + self.kernel_size_number)

    def forward(self, input):
        """
        Forward phase for the convolutional layer, dispatched on ``self.step``.

        ``1``-``4``: K step for mode 0-3 (the kernel uses K_i in place of U_i and Qst[i] as core);
        ``'core'``: augmented core step with bases ``U_hats`` and the 2r-sized core;
        ``'test'``: evaluation with the current ranks;
        ``None``: dense convolution with ``self.weight``.
        Raises ValueError for any other value.
        """

        if self.step == 1:  # first mode  (K_0)

            r = self.dynamic_rank[0]
            Us = [u[:, :self.dynamic_rank[j]] for j, u in enumerate(self.Us) if j != 0]
            other_ranks = [self.dynamic_rank[j] for j, u in enumerate(self.Us) if j != 0]
            Qst = self.Qst[0][:r, :other_ranks[0], :other_ranks[1], :other_ranks[2]]
            # identity on the output mode: K_0 is applied afterwards with an einsum
            Us = [torch.eye(n=Qst.shape[0], device=Qst.device, dtype=Qst.dtype)] + Us
            first_conv = tltorch.functional.tucker_conv(input, tucker_tensor=tltorch.TuckerTensor(Qst, Us,
                                                                                                  rank=self.dynamic_rank),
                                                        stride=self.stride, padding=self.padding,
                                                        dilation=self.dilation)
            result = torch.einsum('fl,nluv->nfuv', self.Ks[0][:, :r], first_conv)
            if self.bias is not None:
                result += self.bias.view(self.bias.shape[0], 1, 1)

            return result

        elif self.step == 2:  # second mode (K_1)

            r = self.dynamic_rank[1]
            Us = [u[:, :self.dynamic_rank[j]] for j, u in enumerate(self.Us) if j != 1]
            other_ranks = [self.dynamic_rank[j] for j, u in enumerate(self.Us) if j != 1]
            # project the input channels with K_1 first, then convolve with an identity on mode 1
            result = torch.einsum('cl,ncxy->nlxy', self.Ks[1][:, :r], input)
            Qst = self.Qst[1][:r, :other_ranks[0], :other_ranks[1], :other_ranks[2]]
            Qst = rearrange(Qst, 'l f j k -> f l j k')
            Us.insert(1, torch.eye(n=Qst.shape[1], device=Qst.device, dtype=Qst.dtype))
            result = tltorch.functional.tucker_conv(result,
                                                    tucker_tensor=tltorch.TuckerTensor(Qst, Us, rank=self.dynamic_rank),
                                                    stride=self.stride, padding=self.padding, dilation=self.dilation,
                                                    bias=self.bias)

            return result

        elif self.step == 3:  # third mode (K_2)

            r = self.dynamic_rank[2]
            Us = [u[:, :self.dynamic_rank[j]] for j, u in enumerate(self.Us) if j != 2]
            other_ranks = [self.dynamic_rank[j] for j, u in enumerate(self.Us) if j != 2]
            Qst = self.Qst[2][:r, :other_ranks[0], :other_ranks[1], :other_ranks[2]]
            Qst = rearrange(Qst, 'j f c k -> f c j k')
            Us.insert(2, self.Ks[2][:, :self.dynamic_rank[2]])
            result = tltorch.functional.tucker_conv(input,
                                                    tucker_tensor=tltorch.TuckerTensor(Qst, Us, rank=self.dynamic_rank),
                                                    stride=self.stride, padding=self.padding, dilation=self.dilation,
                                                    bias=self.bias)

            return result

        elif self.step == 4:  # fourth mode (K_3)

            r = self.dynamic_rank[3]
            Us = [u[:, :self.dynamic_rank[j]] for j, u in enumerate(self.Us) if j != 3]
            other_ranks = [self.dynamic_rank[j] for j, u in enumerate(self.Us) if j != 3]
            Qst = self.Qst[3][:r, :other_ranks[0], :other_ranks[1], :other_ranks[2]]
            Qst = rearrange(Qst, 'j f c k -> f c k j')
            Us.insert(3, self.Ks[3][:, :self.dynamic_rank[3]])
            result = tltorch.functional.tucker_conv(input,
                                                    tucker_tensor=tltorch.TuckerTensor(Qst, Us, rank=self.dynamic_rank),
                                                    stride=self.stride, padding=self.padding, dilation=self.dilation,
                                                    bias=self.bias)
            return result

        elif self.step == 'core':  # core step (S augmented basis forward step)

            C = self.C[:2 * self.dynamic_rank[0], :2 * self.dynamic_rank[1], :2 * self.dynamic_rank[2],
                :2 * self.dynamic_rank[3]]
            Us = [U[:, :2 * self.dynamic_rank[i]] for i, U in enumerate(self.U_hats)]

            result = tltorch.functional.tucker_conv(input, tucker_tensor=tltorch.TuckerTensor(C, Us,
                                                                                              rank=[2 * r for r in
                                                                                                    self.dynamic_rank]),
                                                    bias=self.bias, stride=self.stride, padding=self.padding,
                                                    dilation=self.dilation)
            return result

        elif self.step == 'test':  # step for evaluation

            C = self.C[:self.dynamic_rank[0], :self.dynamic_rank[1], :self.dynamic_rank[2], :self.dynamic_rank[3]]
            Us = [U[:, :self.dynamic_rank[i]] for i, U in enumerate(self.Us)]

            result = tltorch.functional.tucker_conv(input,
                                                    tucker_tensor=tltorch.TuckerTensor(C, Us, rank=self.dynamic_rank),
                                                    bias=self.bias, stride=self.stride, padding=self.padding,
                                                    dilation=self.dilation)
            return result

        elif self.step is None:

            return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation)

        else:

            raise ValueError(f'incorrect step value {self.step}')

    @torch.no_grad()
    def update_step(self, step):
        '''
        update the step (variable representation) for the forward step
        '''
        self.step = step

    @torch.no_grad()
    def K_preprocess_step(self):

        '''
        Set the initial conditions for the K ODEs: K_i = U_i R_i^T, where Mat_i(C)^T = Q_i R_i (QR).
        '''

        r1, r2, r3, r4 = self.dynamic_rank

        for i in range(len(self.C.shape)):
            MAT_i_C = tly.base.unfold(self.C[:r1, :r2, :r3, :r4], mode=i)
            _, S_i_0_T = torch.linalg.qr(MAT_i_C.T)

            S_i_0_T = S_i_0_T[:self.dynamic_rank[i], :self.dynamic_rank[i]]
            Ks = self.Us[i][:, :self.dynamic_rank[i]] @ S_i_0_T.T
            self.Ks[i][:, :self.dynamic_rank[i]] = Ks

    @torch.no_grad()
    def L_preprocess_step(self):

        '''
        sets initial conditions for the integration of the L ODE (not present, just for coherence with linear layer)
        '''

        pass

    @torch.no_grad()
    def S_preprocess_step(self):

        '''
        Set the initial condition of the core ODE: the current r-sized core projected onto the
        augmented bases, C_hat = C x_0 M_0 x_1 M_1 x_2 M_2 x_3 M_3, written into the 2r-sized block of C.
        '''

        low_rank_core = self.C[:self.dynamic_rank[0], :self.dynamic_rank[1], :self.dynamic_rank[2],
                        :self.dynamic_rank[3]]
        # einsum builds a new tensor before the (overlapping) slice of C is overwritten
        self.C[:2 * self.dynamic_rank[0], :2 * self.dynamic_rank[1], :2 * self.dynamic_rank[2],
        :2 * self.dynamic_rank[3]] = torch.einsum('abcd,ia,jb,kc,ld->ijkl',
                                                  [low_rank_core] + [M[:2 * self.dynamic_rank[i], :self.dynamic_rank[i]]
                                                                     for i, M in enumerate(self.M_hats)])

    @staticmethod
    def _orthonormal_basis(A):
        """Return the reduced-QR factor Q of ``A``; falls back to NumPy if the torch QR raises.

        The result always has ``A``'s dtype and device.
        """
        try:
            Q, _ = torch.linalg.qr(A)
        except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
            Q_np, _ = np.linalg.qr(A.detach().cpu().numpy())
            Q = torch.as_tensor(Q_np, dtype=A.dtype, device=A.device)
        return Q

    @torch.no_grad()
    def K_postprocess_step(self):

        '''
        Compute the augmented bases U_hat_i = orth([K_i | U_i]) for the unconventional integrator and
        the matrices M_i = U_hat_i^T U_i.

        When 2 * r_i is not smaller than the mode size, only K_i is orthonormalised and the remaining
        r_i columns of the 2r_i-wide block are zeroed.
        '''

        for i in range(len(self.C.shape)):
            r = self.dynamic_rank[i]
            K = self.Ks[i][:, :r]
            U = self.Us[i][:, :r]
            U_hat = self.U_hats[i]

            if 2 * r >= self.Ks[i].shape[0]:
                Q = self._orthonormal_basis(K)
                U_hat[:, r:2 * r] = 0
                U_hat[:, :r] = Q[:, :r]
            else:
                # Previously a failed QR fell into a bare `except:` whose NumPy result was discarded,
                # silently leaving U_hats[i] stale; the helper now returns a usable basis instead.
                U_hat[:, :2 * r] = self._orthonormal_basis(torch.hstack((K, U)))

            self.M_hats[i][:2 * r, :r] = U_hat[:, :2 * r].T @ U

    @torch.no_grad()
    def L_postprocess_step(self):

        pass

    @torch.no_grad()
    def S_postprocess_step(self):

        '''
        Rank adaptation step: truncate the augmented core with tntorch.round_tucker (eps=self.tau,
        capped at self.rmax), rotate the bases accordingly and store the truncated core.
        '''

        dim = len(self.C.shape)
        r0, r1, r2, r3 = self.dynamic_rank
        C_cum = self.C[:2 * r0, :2 * r1, :2 * r2, :2 * r3]
        factors = tntorch.round_tucker(tntorch.tensor.Tensor(C_cum), eps=self.tau, rmax=self.rmax, dim='all',
                                       algorithm='svd')
        core = factors.tucker_core()
        rmaxs = list(core.shape)
        for i in range(dim):
            rmax = rmaxs[i]
            # new basis = augmented basis rotated by the truncated factor of the core
            self.Us[i][:, :rmax] = self.U_hats[i][:, :2 * self.dynamic_rank[i]] @ factors.Us[i][
                                                                                  :2 * self.dynamic_rank[i], :rmax]
            self.dynamic_rank[i] = int(rmax)

        self.dynamic_rank = [min(rmaxs[i], self.get_r_mod_i(i)) for i in range(dim)]
        self.C[:rmaxs[0], :rmaxs[1], :rmaxs[2], :rmaxs[3]] = core

    @torch.no_grad()
    def update_Q(self):

        '''
        Update the cores Qst used by the K steps from the current low-rank core C.
        '''

        C = self.C[:self.dynamic_rank[0], :self.dynamic_rank[1], :self.dynamic_rank[2], :self.dynamic_rank[3]]
        for i in range(len(self.dims)):
            Mat_i_C = tly.unfold(C, mode=i).T
            Q, _ = torch.linalg.qr(Mat_i_C)
            other_ranks = [self.dynamic_rank[j] for j, u in enumerate(self.Us) if j != i]
            # NOTE(review): Q has shape (prod(other ranks), r_i), not the (shape[i], prod(rest)) layout
            # tly.fold expects, so fold reinterprets Q's memory rather than inverting tly.unfold.
            # Confirm this matches what the K-step forward expects (the textbook right core would be
            # fold(Q.T, mode=i, shape=C.shape)).
            Q_ten = tly.fold(Q, mode=i, shape=[self.dynamic_rank[i]] + other_ranks)
            # Write into the existing frozen buffer instead of replacing the ParameterList entry:
            # assigning a plain tensor there either raises (older torch) or registers a new *trainable*
            # Parameter. dynamic_rank <= rmax, so the slice always fits.
            self.Qst[i][tuple(slice(0, s) for s in Q_ten.shape)] = Q_ten

    @torch.no_grad()
    def get_r_mod_i(self, i):
        """Return min(r_i, prod_{j != i} r_j) for the current dynamic ranks."""

        return min(self.dynamic_rank[i], math.prod([r for j, r in enumerate(self.dynamic_rank) if j != i]))
